{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd74713-1332-4a0b-a55d-3e0d1662b843",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from torch import optim, nn, utils, Tensor\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms import ToTensor\n",
    "import lightning.pytorch as pl\n",
    "\n",
    "# Two plain nn.Modules form the halves of the autoencoder: a flattened\n",
    "# 28 * 28 input is reduced to 64 features and then expanded back again.\n",
    "encoder = nn.Sequential(\n",
    "    nn.Linear(28 * 28, 128),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(128, 64),\n",
    ")\n",
    "decoder = nn.Sequential(\n",
    "    nn.Linear(64, 128),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(128, 28 * 28),\n",
    ")\n",
    "\n",
    "\n",
    "class LitAutoEncoder(pl.LightningModule):\n",
    "    \"\"\"LightningModule that trains an encoder/decoder pair on reconstruction error.\"\"\"\n",
    "\n",
    "    def __init__(self, encoder, decoder):\n",
    "        \"\"\"Keep the two sub-networks as attributes of the module.\"\"\"\n",
    "        super().__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Return the MSE between the flattened inputs and their reconstruction.\n",
    "\n",
    "        This step does not go through ``forward``. The second element of\n",
    "        ``batch`` is unpacked but unused, and the loss is logged under\n",
    "        ``\"train_loss\"`` (to TensorBoard by default, if installed).\n",
    "        \"\"\"\n",
    "        inputs, _ = batch\n",
    "        flat_inputs = inputs.view(inputs.size(0), -1)  # one row per sample\n",
    "        reconstruction = self.decoder(self.encoder(flat_inputs))\n",
    "        loss = nn.functional.mse_loss(reconstruction, flat_inputs)\n",
    "        self.log(\"train_loss\", loss)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Return an Adam optimizer over all parameters with lr=1e-3.\"\"\"\n",
    "        return optim.Adam(self.parameters(), lr=1e-3)\n",
    "\n",
    "\n",
    "# Wrap the two modules in the LightningModule defined above.\n",
    "autoencoder = LitAutoEncoder(encoder=encoder, decoder=decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d35dd91a-8295-4404-aa60-de81cf639c0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup data\n",
    "BATCH_SIZE = 64  # samples per optimisation step (the DataLoader default is 1)\n",
    "\n",
    "# MNIST is downloaded into the current working directory on first use.\n",
    "dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())\n",
    "\n",
    "# Without an explicit batch size every step would see a single image, and\n",
    "# without shuffling a run that only consumes the first batches of each epoch\n",
    "# would keep revisiting the same leading samples in the same order.\n",
    "train_loader = utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96e71aaf-dfa8-43b1-b8a0-bc030693518a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model. The two Trainer arguments keep the run short for rapid\n",
    "# idea iteration: at most 100 training batches per epoch, for 50 epochs.\n",
    "trainer_options = {\"limit_train_batches\": 100, \"max_epochs\": 50}\n",
    "trainer = pl.Trainer(**trainer_options)\n",
    "trainer.fit(autoencoder, train_dataloaders=train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bbc796d-5254-4e6b-ab55-7148ba2a85f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# load checkpoint\n",
    "# Ask the trainer where it saved the checkpoint instead of hard-coding a\n",
    "# \"version_N\" directory: that number changes with every run, so a fixed\n",
    "# path breaks as soon as the notebook is re-run from a fresh kernel.\n",
    "checkpoint = trainer.checkpoint_callback.best_model_path\n",
    "if not checkpoint:\n",
    "    raise FileNotFoundError(\"The trainer has not saved a checkpoint yet; run trainer.fit first.\")\n",
    "autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)\n",
    "\n",
    "# choose your trained nn.Module\n",
    "encoder = autoencoder.encoder\n",
    "encoder.eval()\n",
    "\n",
    "# embed 4 fake images!\n",
    "# The random batch is created on the model's device so the forward pass\n",
    "# never mixes devices.\n",
    "fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)\n",
    "with torch.no_grad():  # inference only: no autograd graph is needed\n",
    "    embeddings = encoder(fake_image_batch)\n",
    "print(\"⚡\" * 20, \"\\nPredictions (4 image embeddings):\\n\", embeddings, \"\\n\", \"⚡\" * 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "355f7837-8e5b-4d1a-8ff6-96650d6ea4f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the autoencoder on the GPU when CUDA is available, otherwise on the CPU.\n",
    "device_name = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device = torch.device(device_name)\n",
    "autoencoder = autoencoder.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eb2d3c7-4f97-450d-867c-568d60482e4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from torchvision.utils import make_grid\n",
    "\n",
    "def show_images(images, title=\"Images\"):\n",
    "    \"\"\"Display a batch of images as a grid with four images per row.\n",
    "\n",
    "    Args:\n",
    "        images: Batch of images in any form accepted by ``make_grid``. The\n",
    "            data may live on any device and may be attached to an\n",
    "            autograd graph.\n",
    "        title: Text shown above the grid.\n",
    "    \"\"\"\n",
    "    # matplotlib converts the tensor to a NumPy array, which PyTorch refuses\n",
    "    # for tensors that require grad or live on a GPU, so detach the grid and\n",
    "    # copy it to the CPU first.\n",
    "    grid_img = make_grid(images, nrow=4).detach().cpu()\n",
    "    plt.figure(figsize=(8, 8))\n",
    "    # make_grid returns channels first; imshow expects channels last.\n",
    "    plt.imshow(grid_img.permute(1, 2, 0))\n",
    "    plt.title(title)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "# Load a batch of images from the dataset\n",
    "images, _ = next(iter(train_loader))\n",
    "show_images(images, title=\"Original Images\")\n",
    "\n",
    "# Preprocess the images: flatten each one and move the batch to the device\n",
    "# the model currently lives on, so the forward pass does not mix devices.\n",
    "images = images.view(images.size(0), -1).to(autoencoder.device)\n",
    "\n",
    "# Generate image from embeddings\n",
    "with torch.no_grad():  # inference only: no autograd graph is needed\n",
    "    embeddings = autoencoder.encoder(images)\n",
    "    reconstructed_images = autoencoder.decoder(embeddings).view(-1, 1, 28, 28)\n",
    "show_images(reconstructed_images, title=\"Reconstructed Images\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env_twitter",
   "language": "python",
   "name": "env_twitter"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
